# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import warnings
from functools import partial

import numpy as np
from scipy.signal import spectrogram

from ..parallel import parallel_func
from ..utils import _check_option, _ensure_int, logger, verbose, warn
from ..utils.numerics import _mask_to_onsets_offsets


# adapted from SciPy
# https://github.com/scipy/scipy/blob/f71e7fad717801c4476312fe1e23f2dfbb4c9d7f/scipy/signal/_spectral_py.py#L2019  # noqa: E501
def _median_biases(n):
    # Compute the biases for 0 to max(n, 1) terms included in a median calc
    biases = np.ones(n + 1)
    # The original SciPy code is:
    #
    # def _median_bias(n):
    #     ii_2 = 2 * np.arange(1., (n - 1) // 2 + 1)
    #     return 1 + np.sum(1. / (ii_2 + 1) - 1. / ii_2)
    #
    # This is a sum over (n-1)//2 terms.
    # The ii_2 terms here for different n are:
    #
    # n=0: []  # 0 terms
    # n=1: []  # 0 terms
    # n=2: []  # 0 terms
    # n=3: [2]  # 1 term
    # n=4: [2]  # 1 term
    # n=5: [2, 4]  # 2 terms
    # n=6: [2, 4]  # 2 terms
    # ...
    #
    # We can get the terms for 0 through n using a cumulative summation and
    # indexing:
    if n >= 3:
        ii_2 = 2 * np.arange(1, (n - 1) // 2 + 1)
        sums = 1 + np.cumsum(1.0 / (ii_2 + 1) - 1.0 / ii_2)
        idx = np.arange(2, n) // 2 - 1
        biases[3:] = sums[idx]
    return biases


def _decomp_aggregate_mask(epoch, func, average, freq_sl):
    _, _, spect = func(epoch)
    spect = spect[..., freq_sl, :]
    # Do the averaging here (per epoch) to save memory
    if average == "mean":
        spect = np.nanmean(spect, axis=-1)
    elif average == "median":
        biases = _median_biases(spect.shape[-1])
        idx = (~np.isnan(spect)).sum(-1)
        spect = np.nanmedian(spect, axis=-1) / biases[idx]
    return spect


def _spect_func(epoch, func, freq_sl, average, *, output="power"):
    """Compute the (optionally aggregated) spectrogram of one data chunk.

    Inputs larger than 10 MB are processed one signal at a time along the
    last axis, which bounds the size of the intermediate per-segment
    spectrogram. Smaller inputs are processed in a single call. ``output``
    is accepted for call compatibility but is not used here.
    """
    # Splitting costs extra Python calls, so it is only worth it for big
    # chunks. A custom spectrogram that averages after each transform would
    # avoid the trade-off, but would add a lot of per-call Python overhead.
    if epoch.nbytes <= 10e6:
        return _decomp_aggregate_mask(
            epoch, func=func, average=average, freq_sl=freq_sl
        )
    return np.apply_along_axis(
        _decomp_aggregate_mask,
        -1,
        epoch,
        func=func,
        average=average,
        freq_sl=freq_sl,
    )


def _check_nfft(n, n_fft, n_per_seg, n_overlap):
    """Ensure n_fft, n_per_seg and n_overlap make sense."""
    if n_per_seg is None and n_fft > n:
        raise ValueError(
            "If n_per_seg is None n_fft is not allowed to be > "
            "n_times. If you want zero-padding, you have to set "
            f"n_per_seg to relevant length. Got n_fft of {n_fft} while"
            f" signal length is {n}."
        )
    n_per_seg = n_fft if n_per_seg is None or n_per_seg > n_fft else n_per_seg
    n_per_seg = n if n_per_seg > n else n_per_seg
    if n_overlap >= n_per_seg:
        raise ValueError(
            "n_overlap cannot be greater than n_per_seg (or n_fft). Got n_overlap "
            f"of {n_overlap} while n_per_seg is {n_per_seg}."
        )
    return n_fft, n_per_seg, n_overlap


@verbose
def psd_array_welch(
    x,
    sfreq,
    fmin=0,
    fmax=np.inf,
    n_fft=256,
    n_overlap=0,
    n_per_seg=None,
    n_jobs=None,
    average="mean",
    window="hamming",
    remove_dc=True,
    *,
    output="power",
    verbose=None,
):
    """Compute power spectral density (PSD) using Welch's method.

    Welch's method is described in :footcite:t:`Welch1967`.

    Parameters
    ----------
    x : array, shape=(..., n_times)
        The data to compute PSD from.
    sfreq : float
        The sampling frequency.
    fmin : float
        The lower frequency of interest.
    fmax : float
        The upper frequency of interest.
    n_fft : int
        The length of FFT used, must be ``>= n_per_seg`` (default: 256).
        The segments will be zero-padded if ``n_fft > n_per_seg``.
    n_overlap : int
        The number of points of overlap between segments. Must be smaller
        than the effective ``n_per_seg``, otherwise a ``ValueError`` is
        raised. The default value is 0.
    n_per_seg : int | None
        Length of each Welch segment (windowed with a Hamming window). Defaults
        to None, which sets n_per_seg equal to n_fft. Values larger than
        ``n_fft`` or ``n_times`` are reduced to that limit.
    %(n_jobs)s
    %(average_psd)s

        .. versionadded:: 0.19.0
    %(window_psd)s

        .. versionadded:: 0.22.0
    %(remove_dc)s

    output : str
        The format of the returned ``psds`` array, ``'complex'`` or
        ``'power'``:

        * ``'power'`` : the power spectral density is returned.
        * ``'complex'`` : the complex fourier coefficients are returned per
          window.

        .. versionadded:: 1.4.0
    %(verbose)s

    Returns
    -------
    psds : ndarray, shape (..., n_freqs) or (..., n_freqs, n_segments)
        The power spectral densities. If ``average='mean'`` or
        ``average='median'``, the returned array has shape
        ``x.shape[:-1] + (n_freqs,)``, i.e. the time axis is replaced by a
        frequency axis. If ``average=None``, the returned array has shape
        ``x.shape[:-1] + (n_freqs, n_segments)``, where the last axis holds
        the unaggregated segments.
    freqs : ndarray, shape (n_freqs,)
        The frequencies.

    Notes
    -----
    NaN samples that occur at the same time points in every signal (e.g.,
    from ``"bad_*"`` annotations) are excluded. The remaining good spans
    are analyzed separately. When averaging, their results are combined
    with weights equal to the number of samples each span contributes.
    With ``average=None``, the segments of all spans are concatenated.
    Any other non-finite values (Inf anywhere, or NaNs that are not aligned
    across signals) make the PSD of the affected signals NaN, and a warning
    is emitted.

    .. versionadded:: 0.14.0

    References
    ----------
    .. footbibliography::
    """
    _check_option("average", average, (None, False, "mean", "median"))
    _check_option("output", output, ("power", "complex"))
    detrend = "constant" if remove_dc else False
    mode = "complex" if output == "complex" else "psd"
    n_fft = _ensure_int(n_fft, "n_fft")
    n_overlap = _ensure_int(n_overlap, "n_overlap")
    if n_per_seg is not None:
        n_per_seg = _ensure_int(n_per_seg, "n_per_seg")
    if average is False:
        average = None

    dshape = x.shape[:-1]
    n_times = x.shape[-1]
    x = x.reshape(-1, n_times)

    # Prep the PSD
    n_fft, n_per_seg, n_overlap = _check_nfft(n_times, n_fft, n_per_seg, n_overlap)
    win_size = n_fft / float(sfreq)
    logger.info(f"Effective window size : {win_size:0.3f} (s)")
    freqs = np.arange(n_fft // 2 + 1, dtype=float) * (sfreq / n_fft)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        raise ValueError(f"No frequencies found between fmin={fmin} and fmax={fmax}")
    freq_sl = slice(*(np.where(freq_mask)[0][[0, -1]] + [0, 1]))
    del freq_mask
    freqs = freqs[freq_sl]

    # Only samples covered by full segments enter the spectrogram, so
    # non-finite values past `analyzed_end` cannot contaminate the result.
    step = max(int(n_per_seg) - int(n_overlap), 1)
    if n_times >= n_per_seg:
        n_segments = 1 + (n_times - n_per_seg) // step
        analyzed_end = step * (n_segments - 1) + n_per_seg
    else:
        n_segments = 0
        analyzed_end = 0

    nan_mask_full = np.isnan(x)
    nan_present = nan_mask_full.any()
    if nan_present:
        good_mask_full = ~nan_mask_full
        # identical NaN masks across all signals are treated as bad spans
        aligned_nan = np.allclose(good_mask_full, good_mask_full[[0]], equal_nan=True)
    else:
        aligned_nan = False

    if analyzed_end > 0:
        # Inf always counts as non-finite per-channel
        nonfinite_mask = np.isinf(x[..., :analyzed_end])
        # NaNs count per-channel only if NOT aligned (i.e., not annotations)
        if nan_present and not aligned_nan:
            nonfinite_mask |= nan_mask_full[..., :analyzed_end]
        bad_ch = nonfinite_mask.any(axis=-1)
    else:
        bad_ch = np.zeros(x.shape[0], dtype=bool)

    if bad_ch.any():
        warn(
            "Non-finite values (NaN/Inf) detected in some channels; PSD for "
            "those channels will be NaN.",
        )
        # avoid downstream NumPy warnings by zeroing bad channels;
        # will overwrite their PSD rows with NaN at the end
        x = x.copy()
        x[bad_ch] = 0.0

    # Parallelize across first N-1 dimensions
    logger.debug(
        f"Spectogram using {n_fft}-point FFT on {n_per_seg} samples with "
        f"{n_overlap} overlap and {window} window"
    )

    parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
    _func = partial(
        spectrogram,
        detrend=detrend,
        noverlap=n_overlap,
        nperseg=n_per_seg,
        nfft=n_fft,
        fs=sfreq,
        window=window,
        mode=mode,
    )
    if nan_present and aligned_nan:
        # Aligned NaNs across channels → treat as bad annotations.
        t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask_full[0])
        x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
        if not x_splits:
            raise ValueError("No good (non-NaN) data found, cannot compute the PSD.")
        # weights reflect the number of samples used from each span. For spans longer
        # than `n_per_seg`, trailing samples may be discarded. For spans shorter than
        # `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically
        # reduces `n_per_seg` to match the span length (with a warning).
        span_lengths = [span.shape[-1] for span in x_splits]
        weights = [
            w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths
        ]
        if average is None:
            # Unaggregated segments from different spans cannot be averaged
            # together (their segment counts generally differ), so stack them
            # along the segment axis instead.
            agg_func = partial(np.concatenate, axis=-1)
        else:
            agg_func = partial(np.average, axis=0, weights=weights)
        if n_jobs > 1:
            logger.info(
                f"Data split into {len(x_splits)} (probably unequal) chunks due to "
                '"bad_*" annotations. Parallelization may be sub-optimal.'
            )
        if any(length < n_per_seg for length in span_lengths):
            logger.info(
                "At least one good data span is shorter than n_per_seg, and will be "
                "analyzed with a shorter window than the rest of the file."
            )

        def func(*args, **kwargs):
            # swallow SciPy warnings caused by short good data spans
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    action="ignore",
                    module="scipy",
                    category=UserWarning,
                    message=r"nperseg = \d+ is greater than input length",
                )
                return _func(*args, **kwargs)

    else:
        # Either no NaNs, or NaNs are not aligned across channels.
        if nan_present and not aligned_nan:
            logger.info(
                "NaN masks are not aligned across channels; treating NaNs as "
                "per-channel contamination."
            )
        x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
        agg_func = partial(np.concatenate, axis=0)
        func = _func
    f_spect = parallel(
        my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
        for d in x_splits
    )
    psds = agg_func(f_spect)
    shape = dshape + (len(freqs),)
    if average is None:
        shape = shape + (-1,)

    if bad_ch.any():
        psds[bad_ch] = np.nan

    psds = psds.reshape(shape)
    return psds, freqs
